# ======================================================================
# Copyright CERFACS (January 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

"""This module implements some subroutines useful to implement Ahokas' circuits.

The quantum routines implemented in this module are all used as subroutines to
implement the circuits detailed in
:cite:`2004-ahokas-graeme-robert-improved-algorithms-for-approximate-quantum-fourier-transforms-and-sparse-hamiltonian-simulations`.

TODOs
=====

1. Optimise the implementation of exp_ZZFt and exp_ZZF2r_t. For the moment they use too
   many gates and too many controls. Maybe add an ancilla qubit.
2. Optimise the :math:`R_{zz}` gate so that the gate alone is the identity but it
   becomes a real gate when controlled. **MIGHT BREAK SOME TESTS**.
"""

import numpy
from qat.lang.AQASM.gates import X, H, CNOT, CCNOT, RX, RZ, PH
from qat.lang.AQASM.routines import QRoutine
from qat.lang.AQASM.misc import build_gate


@build_gate("W", [], arity=2)
def W() -> QRoutine:
    """Build the two-qubit W gate from Graeme Ahokas' master thesis.

    The gate is decomposed as CNOT (control 0, target 1), then a Hadamard on
    qubit 0 controlled by qubit 1, then CNOT (control 0, target 1) again.
    Decomposition taken from
    https://quantumcomputing.stackexchange.com/a/5059/1386.
    """
    circuit = QRoutine(arity=2)
    controlled_h = H.ctrl()
    sequence = (
        (CNOT, (0, 1)),
        (controlled_h, (1, 0)),
        (CNOT, (0, 1)),
    )
    for gate, qubits in sequence:
        circuit.apply(gate, *qubits)
    return circuit


@build_gate("G", [], arity=1)
def G() -> QRoutine:
    """Build the single-qubit G gate from Graeme Ahokas' master thesis.

    G is RX(pi/2) followed by Rzz(-pi/2) on the same qubit.
    Decomposition taken from
    https://quantumcomputing.stackexchange.com/a/5059/1386.
    """
    quarter_turn = numpy.pi / 2
    circuit = QRoutine(arity=1)
    circuit.apply(RX(quarter_turn), 0)
    circuit.apply(rzz(-quarter_turn), 0)
    return circuit


@build_gate("Rzz", [float], arity=1)
def rzz(theta: float) -> QRoutine:
    """Single-qubit phase gate built from the sequence PH(theta), X, PH(theta), X.

    Assuming PH(theta) = diag(1, e^{i theta}), the X-conjugated copy equals
    diag(e^{i theta}, 1), so the whole sequence is e^{i theta} times the
    identity. On its own it is therefore only a global phase; it becomes a
    relative phase on the control qubit once the gate is controlled.

    :param theta: Phase angle.
    """
    circuit = QRoutine(arity=1)
    # Two identical (PH, X) rounds: PH, X, PH, X.
    for _ in range(2):
        circuit.apply(PH(theta), 0)
        circuit.apply(X, 0)
    return circuit


@build_gate("toffoli_10", [], arity=3)
def toffoli_10() -> QRoutine:
    """Toffoli gate with a positive control on qubit 0 and a negative control on qubit 1.

    Qubit 2 is flipped iff qubit 0 is |1> and qubit 1 is |0>. The negative
    control is obtained by conjugating qubit 1 with X gates around a
    standard CCNOT.
    """
    positive_ctrl, negative_ctrl, target = 0, 1, 2
    circuit = QRoutine(arity=3)
    circuit.protected_apply(X, negative_ctrl)
    circuit.apply(CCNOT, positive_ctrl, negative_ctrl, target)
    circuit.protected_apply(X, negative_ctrl)
    return circuit


@build_gate("A", [int], arity=lambda n: 2 * n + 1)
def A(n: int) -> QRoutine:
    """Build the A gate from Graeme Ahokas' master thesis.

    The routine acts on 2*n + 1 qubits, allocated as an ``x`` register
    (n qubits), then an ``m`` register (n qubits), then a single qubit ``p``.
    It first applies W to every pair (x[i], m[i]). It then applies
    ``toffoli_10`` to every (x[i], m[i], p), which flips ``p`` when x[i] is
    |1> and m[i] is |0>.

    :param n: Number of qubits in each of the ``x`` and ``m`` registers.
    """
    routine = QRoutine()
    x_reg = routine.new_wires(n)
    m_reg = routine.new_wires(n)
    target = routine.new_wires(1)

    pairs = [(x_reg[idx], m_reg[idx]) for idx in range(n)]
    # All W gates are applied before any toffoli_10 gate.
    for x_wire, m_wire in pairs:
        routine.apply(W(), x_wire, m_wire)
    for x_wire, m_wire in pairs:
        routine.apply(toffoli_10(), x_wire, m_wire, target)
    return routine


@build_gate("two_qubits_parity_gate", [], arity=3)
def Two_qubits_parity_gate() -> QRoutine:
    r"""Quantum gate used to store the parity of 2 qubits in a third one.

    The parity of two qubits is 0 if the two qubits are in the same state, else it is 1.
    The third qubit is **required** to be in the :math:`\ket{0}` state.

    Implemented as CNOT (control 0, target 2) followed by CNOT (control 1,
    target 2): qubit 2 is XOR-ed with both inputs, so starting from
    :math:`\ket{0}` it ends up holding their parity.
    """
    # Raw docstring: "\k" in a regular string is an invalid escape sequence
    # (SyntaxWarning since Python 3.12).
    rout = QRoutine()
    rout.apply(CNOT, 0, 2)
    rout.apply(CNOT, 1, 2)
    return rout


@build_gate("exp_ZFt", [int, float], arity=lambda n, t: n + 1)
def exp_ZFt(n: int, time: float) -> QRoutine:
    r"""Perform the e^{iZ\otimes Ft} operator.

    The returned procedure takes n+1 qubits as parameters: ::

        |  |p>  |        |w>        |
        |   .   |   .   .   .   .   |
        |   0   |   1           n   |

    - The first qubit (called p) should be initialised with the parity count (see the
        article).
    - The n following qubits hold the integer weight w of the Hamiltonian matrix.
        The code gives weight :math:`2^j` to the qubit w[n-1-j]. The qubit n°n
        (the last one) is therefore the least significant one, and the qubit n°1
        (0 being p) is the most significant one. NOTE(review): this is big-endian
        with respect to the qubit order; confirm it matches how callers encode w.

    Assuming PH(theta) = diag(1, e^{i theta}), the state :math:`\ket{p}\ket{w}`
    picks up the phase :math:`e^{-i (-1)^p w t}`, i.e. the applied operator is
    :math:`e^{-iZ\otimes Ft}`. NOTE(review): the summary line has no minus sign;
    confirm the intended sign convention against the article.

    :param n: Number of qubits used to encode an integer on :math:`\ket{w}`.
    :param time: Time of evolution for the hermitian operator ZF.
    """

    def cPH(theta: float):
        # Controlled PH(theta): adds e^{i theta} only when control and target are |1>.
        return PH(theta).ctrl()

    routine = QRoutine()
    p = routine.new_wires(1)
    w = routine.new_wires(n)

    # Branch p = |0>: flip p so the controls fire on |0>, then add -2^j * time
    # for every set bit of w.
    routine.protected_apply(X, p)
    for j in range(n):
        routine.apply(cPH(-(2 ** j) * time), p, w[n - 1 - j])
    routine.protected_apply(X.dag(), p)

    # Branch p = |1>: add +2^j * time for every set bit of w.
    for j in range(n):
        routine.apply(cPH(2 ** j * time), p, w[n - 1 - j])

    return routine


@build_gate("exp_ZF2rt", [int, float], arity=lambda r, t: 3 * r + 1)
def exp_ZF2r_t(r: int, time: float) -> QRoutine:
    r"""Perform the e^{iZ\otimes F/2^{2r} t} operator.

    The returned procedure takes 3*r+1 qubits as parameters: ::

        |  |p>  |        |w>        |
        |   .   |   .   .   .   .   |
        |   0   |   1          3*r  |

    - The first qubit (called p) should be initialised with the parity count (see the
        article).
    - The 3*r following qubits hold the integer weight w of the Hamiltonian matrix,
        rescaled by :math:`2^{-2r}`. The code gives weight :math:`2^{j-2r}` to the
        qubit w[3r-1-j]. The qubit n°3r (the last one) is therefore the least
        significant one (weight :math:`2^{-2r}`), and the qubit n°1 (0 being p)
        is the most significant one (weight :math:`2^{r-1}`). NOTE(review): this
        is big-endian with respect to the qubit order; confirm it matches how
        callers encode w.

    Assuming PH(theta) = diag(1, e^{i theta}), the state :math:`\ket{p}\ket{w}`
    picks up the phase :math:`e^{-i (-1)^p (w / 2^{2r}) t}`, i.e. the applied
    operator is :math:`e^{-iZ\otimes F/2^{2r} t}`. NOTE(review): the summary line
    has no minus sign; confirm the intended sign convention against the article.

    :param r: The weight register has 3*r qubits.
    :param time: Time of evolution for the hermitian operator :math:`Z\otimes F/2^{2r}`.
    """

    def cPH(theta: float):
        # Controlled PH(theta): adds e^{i theta} only when control and target are |1>.
        return PH(theta).ctrl()

    routine = QRoutine()

    p = routine.new_wires(1)
    w = routine.new_wires(3 * r)

    # Branch p = |0>: flip p so the controls fire on |0>, then add
    # -2^(j-2r) * time for every set bit of w.
    routine.protected_apply(X, p)
    for j in range(3 * r):
        routine.apply(cPH(-(2 ** (j - 2 * r)) * time), p, w[3 * r - 1 - j])
    routine.protected_apply(X, p)

    # Branch p = |1>: add +2^(j-2r) * time for every set bit of w.
    for j in range(3 * r):
        routine.apply(cPH(2 ** (j - 2 * r) * time), p, w[3 * r - 1 - j])

    return routine


@build_gate("exp_ZZFt", [int, float], arity=lambda n, t: n + 2)
def exp_ZZFt(n: int, time: float) -> QRoutine:
    r"""Perform the e^{iZ\otimes Z \otimes Ft} operator.

    The returned procedure takes n+2 qubits as parameters: ::

        |  |p>  |  |s>  |        |w>        |
        |   .   |   .   |   .   .   .   .   |
        |   0   |   1   |   2          n+1  |

    - The first qubit (called p) should be initialised with the parity count (see the
        article).
    - The second qubit (called s) should be initialised with the sign of the
        corresponding weight.
    - The n following qubits hold the integer weight w of the Hamiltonian matrix.
        The code gives weight :math:`2^j` to the qubit w[n-1-j]. The qubit n°(n+1)
        (the last one) is therefore the least significant one, and the qubit n°2
        (0 being p, 1 being s) is the most significant one. NOTE(review): this is
        big-endian with respect to the qubit order; confirm it matches how callers
        encode w.

    Assuming PH(theta) = diag(1, e^{i theta}), the state :math:`\ket{p}\ket{s}\ket{w}`
    picks up the phase :math:`e^{-i (-1)^{p \oplus s} w t}`, i.e. the applied
    operator is :math:`e^{-iZ\otimes Z\otimes Ft}`. NOTE(review): the summary line
    has no minus sign; confirm the intended sign convention against the article.

    :param n: Number of qubits used to encode an integer on :math:`\ket{w}`.
    :param time: Time of evolution for the hermitian operator ZZF.
    """

    def ccPH(theta: float):
        # Doubly-controlled PH(theta): adds e^{i theta} only when both controls
        # and the target are |1>.
        return PH(theta).ctrl().ctrl()

    routine = QRoutine()
    p = routine.new_wires(1)
    s = routine.new_wires(1)
    w = routine.new_wires(n)

    # Branch p = |0>, s = |1> (p flipped): +2^j * time per set bit of w.
    routine.protected_apply(X, p)
    for j in range(n):
        routine.apply(ccPH(2 ** j * time), p, s, w[n - 1 - j])
    routine.protected_apply(X.dag(), p)

    # Branch p = |1>, s = |0> (s flipped): +2^j * time per set bit of w.
    routine.protected_apply(X, s)
    for j in range(n):
        routine.apply(ccPH(2 ** j * time), p, s, w[n - 1 - j])
    routine.protected_apply(X.dag(), s)

    # Branch p = |1>, s = |1> (no flip): -2^j * time per set bit of w.
    for j in range(n):
        routine.apply(ccPH(-(2 ** j) * time), p, s, w[n - 1 - j])

    # Branch p = |0>, s = |0> (both flipped): -2^j * time per set bit of w.
    routine.protected_apply(X, s)
    routine.protected_apply(X, p)
    for j in range(n):
        routine.apply(ccPH(-(2 ** j) * time), p, s, w[n - 1 - j])
    routine.protected_apply(X.dag(), p)
    routine.protected_apply(X.dag(), s)

    return routine


# Arity: 1 (p) + 1 (s) + 3*r (w) qubits, matching the wires allocated in the
# routine (the previous declaration, 3*r + 1, was one qubit short).
@build_gate("exp_ZZF2rt", [int, float], arity=lambda r, t: 3 * r + 2)
def exp_ZZF2r_t(r: int, time: float) -> QRoutine:
    r"""Perform the e^{iZ\otimes Z\otimes F/2^{2r} t} operator.

    The returned procedure takes 3r+2 qubits as parameters: ::

        |  |p>  |  |s>  |        |w>        |
        |   .   |   .   |   .   .   .   .   |
        |   0   |   1   |   2         3*r+1 |

    - The first qubit (called p) should be initialised with the parity count (see the
        article).
    - The second qubit (called s) should be initialised with the sign of the
        corresponding weight.
    - The 3*r following qubits hold the integer weight w of the Hamiltonian matrix,
        rescaled by :math:`2^{-2r}`. The code gives weight :math:`2^{j-2r}` to the
        qubit w[3r-1-j]. The qubit n°(3r+1) (the last one) is therefore the least
        significant one (weight :math:`2^{-2r}`), and the qubit n°2 (0 being p,
        1 being s) is the most significant one (weight :math:`2^{r-1}`).
        NOTE(review): this is big-endian with respect to the qubit order; confirm
        it matches how callers encode w.

    Assuming PH(theta) = diag(1, e^{i theta}), the state :math:`\ket{p}\ket{s}\ket{w}`
    picks up the phase :math:`e^{-i (-1)^{p \oplus s} (w / 2^{2r}) t}`, i.e. the
    applied operator is :math:`e^{-iZ\otimes Z\otimes F/2^{2r} t}`. NOTE(review):
    the summary line has no minus sign; confirm the intended sign convention
    against the article.

    :param r: The weight register has 3*r qubits.
    :param time: Time of evolution for the hermitian operator
        :math:`Z\otimes Z\otimes F/2^{2r}`.
    """

    def ccPH(theta: float):
        # Doubly-controlled PH(theta): adds e^{i theta} only when both controls
        # and the target are |1>.
        return PH(theta).ctrl().ctrl()

    routine = QRoutine()
    p = routine.new_wires(1)
    s = routine.new_wires(1)
    w = routine.new_wires(3 * r)

    # Branch p = |0>, s = |1> (p flipped): +2^(j-2r) * time per set bit of w.
    routine.protected_apply(X, p)
    for j in range(3 * r):
        routine.apply(ccPH((2 ** (j - 2 * r)) * time), p, s, w[3 * r - 1 - j])
    routine.protected_apply(X.dag(), p)

    # Branch p = |1>, s = |0> (s flipped): +2^(j-2r) * time per set bit of w.
    routine.protected_apply(X, s)
    for j in range(3 * r):
        routine.apply(ccPH((2 ** (j - 2 * r)) * time), p, s, w[3 * r - 1 - j])
    routine.protected_apply(X.dag(), s)

    # Branch p = |1>, s = |1> (no flip): -2^(j-2r) * time per set bit of w.
    for j in range(3 * r):
        routine.apply(ccPH(-(2 ** (j - 2 * r)) * time), p, s, w[3 * r - 1 - j])

    # Branch p = |0>, s = |0> (both flipped): -2^(j-2r) * time per set bit of w.
    routine.protected_apply(X, s)
    routine.protected_apply(X, p)
    for j in range(3 * r):
        routine.apply(ccPH(-(2 ** (j - 2 * r)) * time), p, s, w[3 * r - 1 - j])
    routine.protected_apply(X.dag(), p)
    routine.protected_apply(X.dag(), s)

    return routine


# NOTE: registered as "exp_ZZt". It previously reused the name "exp_ZZFt",
# which is already taken by the (int, float) gate built by exp_ZZFt.
@build_gate("exp_ZZt", [float], arity=2)
def exp_ZZt(time: float) -> QRoutine:
    r"""Perform the e^{iZ\otimes Zt} operator.

    The returned procedure takes 2 qubits as parameters: ::

        |  |p>  |  |s>  |
        |   .   |   .   |
        |   0   |   1   |

    - The first qubit (called p) should be initialised with the parity count (see the
        article).
    - The second qubit (called s) should be initialised with the sign (see the article).

    Assuming RZ(theta) = diag(e^{-i theta/2}, e^{i theta/2}), the state
    :math:`\ket{p}\ket{s}` picks up the phase :math:`e^{-i (-1)^{p \oplus s} t}`,
    i.e. the applied operator is :math:`e^{-iZ\otimes Zt}`. NOTE(review): the
    summary line has no minus sign; confirm the intended sign convention against
    the article.

    :param time: Time of evolution for the hermitian operator ZZ.
    """

    def cRZ(theta: float):
        # RZ(theta) on the target, controlled by the first qubit.
        return RZ(theta).ctrl()

    routine = QRoutine()
    p = routine.new_wires(1)
    s = routine.new_wires(1)

    # Branch p = |0>: flip p so the control fires on |0>, rotate s by +2*time.
    routine.protected_apply(X, p)
    routine.apply(cRZ(2 * time), p, s)
    routine.protected_apply(X.dag(), p)
    # Branch p = |1>: rotate s by -2*time.
    routine.apply(cRZ(-2 * time), p, s)

    return routine
